{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "My2AZpV_RuTs"
      },
      "source": [
        "# Sharding\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/sharding.ipynb)\n",
        "\n",
        "Sharding for Gemma models. This example runs inference with the Gemma 3 27B model on a TPU v2, using 8 devices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "94CVV9ZxKVDO"
      },
      "outputs": [],
      "source": [
        "!pip install -q gemma"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "executionInfo": {
          "elapsed": 3414,
          "status": "ok",
          "timestamp": 1741527331115,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "PXBc1hRKRuTt"
      },
      "outputs": [],
      "source": [
        "# Common imports\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "# Gemma imports\n",
        "from gemma import gm\n",
        "from kauldron import kd"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i_DEVehe3v5Y"
      },
      "source": [
        "For this colab, make sure you are connected to a TPU runtime by selecting `Change runtime type` \u003e `v2-8 TPU`, which gives access to multiple accelerators. JAX should then report multiple devices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "executionInfo": {
          "elapsed": 1686,
          "status": "ok",
          "timestamp": 1741527337025,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "S_E1StHpif8C",
        "outputId": "bdc915ed-e7c1-4aac-9ee5-baa0a8c81fad"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "8"
            ]
          },
          "execution_count": 2,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "jax.device_count()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XzEj3PYvX1Sm"
      },
      "source": [
        "Create the model. Here we use the 27B architecture; the weights are restored in the next step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "executionInfo": {
          "elapsed": 10101,
          "status": "ok",
          "timestamp": 1741527355077,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "xizXTfAEXu-m"
      },
      "outputs": [],
      "source": [
        "model = gm.nn.Gemma3_27B()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4lol2Xhhp-4a"
      },
      "source": [
        "When restoring the weights, you can pass the `sharding=` argument to `gm.ckpts.load_params`. Here we use a naive `kd.sharding.FSDPSharding` heuristic."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "executionInfo": {
          "elapsed": 20991,
          "status": "ok",
          "timestamp": 1741527376235,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "8ozH7eJ3p_HU"
      },
      "outputs": [],
      "source": [
        "params = gm.ckpts.load_params(\n",
        "    gm.ckpts.CheckpointPath.GEMMA3_27B_IT,\n",
        "    sharding=kd.sharding.FSDPSharding(),\n",
        ")"
      ]
    },
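    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To check how the restored weights are laid out, you can map every array in the params tree to its `.sharding` attribute. The snippet below is a minimal sketch in plain JAX; the small `params` dictionary is a hypothetical stand-in (names and shapes are made up) so that the snippet runs without loading the checkpoint.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "# Hypothetical stand-in for the restored params tree (made-up names and shapes).\n",
        "params = {'embedder': {'input_embedding': jnp.zeros((8, 4))}}\n",
        "\n",
        "# Replace every leaf array by the sharding object that describes its placement.\n",
        "shardings = jax.tree_util.tree_map(lambda x: x.sharding, params)\n",
        "print(shardings)\n",
        "```\n",
        "\n",
        "Applied to the real `params`, the same `tree_map` call shows which sharding each weight received."
      ]
    },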
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AH0eWFWJaiNk"
      },
      "source": [
        "## Sampling\n",
        "\n",
        "Test the sharding using the `gm.text.ChatSampler`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "executionInfo": {
          "elapsed": 14292,
          "status": "ok",
          "timestamp": 1741527483123,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "6R5J42EiZtkC",
        "outputId": "29195151-5a5c-453b-9f33-c2178ce5b5de"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Okay, here's a fascinating and relatively unknown fact about the brain:\n",
            "\n",
            "**Your brain actively \"cleans itself\" during sleep with a system called the glymphatic system, and this cleaning process is *much* more efficient when you sleep on your side.**\n",
            "\n",
            "Here's the breakdown:\n",
            "\n",
            "* **The Glymphatic System:** For a long time, it was thought the brain didn't have a traditional lymphatic system (which clears waste from the body).  But in 2012, researchers discovered the glymphatic system. It's essentially a brain-wide waste clearance system that uses cerebrospinal fluid (CSF) to flush out metabolic waste products that build up during waking hours – things like amyloid-beta, a protein associated with Alzheimer's disease.\n",
            "* **How it Works:** CSF flows *along* arteries into the brain tissue and then drains out along veins. This flow is significantly enhanced during sleep.\n",
            "* **Side Sleeping is Key:**  Research (particularly studies using MRI scans) has shown that sleeping on your side – *especially* the left side – is the most effective position for clearing waste from the brain. This is because the lateral position allows gravity to assist the flow of CSF and facilitates the clearance of interstitial fluid (the fluid between brain cells) and waste products.  Sleeping on your back is *less* effective, and sleeping on your stomach is the *least* effective.\n",
            "\n",
            "**Why is this relatively unknown?** The glymphatic system is a relatively recent discovery, and research is still ongoing.  It's also a complex system to study.  \n",
            "\n",
            "**Source/Further Reading:**\n",
            "\n",
            "*   **Oregon Health \u0026 Science University (OHSU) - The Brain's Cleaning System:** [https://www.ohsu.edu/news/2013/08/brain-cleanses-itself-during-sleep](https://www.ohsu.edu/news/2013/08/brain-cleanses-itself-during-sleep)\n",
            "*   **ScienceAlert - Scientists Discover How You Can Optimize Your Brain's Cleaning System While You Sleep:** [https://www.sciencealert.com/scientists-discover-how-you-can-optimize-your-brain-s-cleaning-system-while-you-sleep](https://www.sciencealert.com/scientists-discover-how-you-can\n"
          ]
        }
      ],
      "source": [
        "sampler = gm.text.ChatSampler(model=model, params=params)\n",
        "\n",
        "out = sampler.chat('Tell me an unknown interesting biology fact about the brain.')\n",
        "print(out)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BexrfJgQIoY3"
      },
      "source": [
        "Even though we've learned something new, the model still hallucinated URLs, which shows the limitations of current systems."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TPf01271RuTs"
      },
      "source": [
        "## Calling the model directly\n",
        "\n",
        "It's also possible to call the model directly. For this, the input first has to be encoded manually."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LBnxVNZFgCyG"
      },
      "outputs": [],
      "source": [
        "tokenizer = gm.text.Gemma3Tokenizer()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c1PCIsJSb-tk"
      },
      "outputs": [],
      "source": [
        "prompt = tokenizer.encode('My name is', add_bos=True)  # /!\\ Don't forget to add the BOS token\n",
        "prompt = jnp.asarray(prompt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yk_76bPSf6X6"
      },
      "source": [
        "When using sharding, the input also has to be sharded. Here we have a single prompt, so we use `kd.sharding.REPLICATED` so that each device gets a copy of the prompt.\n",
        "\n",
        "During training, the prompts are usually batched and padded, then sharded with `kd.sharding.FIRST_DIM` so that the batch is distributed across the devices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lmk1jP9NguyM"
      },
      "outputs": [],
      "source": [
        "prompt = kd.sharding.with_sharding_constraint(prompt, kd.sharding.REPLICATED)"
      ]
    },
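    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The difference between a replicated layout and a first-dimension layout can be sketched in plain JAX. This is an illustration only, not how `kd.sharding` is implemented: the mesh axis name `'devices'`, the token values and the batch shape are assumptions made for the example.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
        "\n",
        "# 1D mesh over all available devices; the axis name is arbitrary.\n",
        "mesh = Mesh(np.array(jax.devices()), axis_names=('devices',))\n",
        "\n",
        "# Replicated: an empty PartitionSpec puts a full copy on every device.\n",
        "replicated = NamedSharding(mesh, P())\n",
        "# First dimension: axis 0 (the batch axis) is split across the mesh.\n",
        "first_dim = NamedSharding(mesh, P('devices'))\n",
        "\n",
        "# A single prompt (placeholder token ids) is replicated.\n",
        "prompt = jax.device_put(jnp.arange(4), replicated)\n",
        "\n",
        "# A padded batch whose size is a multiple of the device count is split.\n",
        "batch = jnp.zeros((2 * jax.device_count(), 16), dtype=jnp.int32)\n",
        "batch = jax.device_put(batch, first_dim)\n",
        "\n",
        "print(prompt.sharding.is_fully_replicated)\n",
        "print(batch.addressable_shards[0].data.shape)\n",
        "```\n",
        "\n",
        "Each device holds the whole prompt but only two rows of the batch, which is why the batch size has to be divisible by the number of devices."
      ]
    },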
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 35
        },
        "executionInfo": {
          "elapsed": 8878,
          "status": "ok",
          "timestamp": 1740077558479,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "G3Kbo9hgRuTt",
        "outputId": "00d3db7d-65d5-48e1-97eb-232dc2aabeef"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "' Mary'"
            ]
          },
          "execution_count": 54,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Run the model\n",
        "out = model.apply(\n",
        "    {'params': params},\n",
        "    tokens=prompt,\n",
        "    return_last_only=True,  # Only predict the last token\n",
        ")\n",
        "\n",
        "\n",
        "# Sample a token from the predicted logits\n",
        "next_token = jax.random.categorical(\n",
        "    jax.random.key(1),\n",
        "    out.logits\n",
        ")\n",
        "tokenizer.decode(next_token)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TNb88k3EF5gu"
      },
      "source": [
        "You can also plot the next-token probability distribution."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 542
        },
        "executionInfo": {
          "elapsed": 8373,
          "status": "ok",
          "timestamp": 1740077578530,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "mIkOdE9dF45s",
        "outputId": "ae004303-f1c5-47ca-9601-14a52ef5080f"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Bar chart: Probability Distribution of Tokens (x-axis: Tokens, y-axis: Probability). Highest-probability tokens: ' Sarah' (0.067), ' [' (0.036), ' Dr' (0.025)."
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "tokenizer.plot_logits(out.logits)"
      ]
    },
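    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you want the numbers rather than a plot, the probabilities can be computed from the logits with a softmax. The sketch below uses made-up logits over a 5-token vocabulary as a stand-in for `out.logits`, so it runs without the model.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "# Made-up logits over a tiny vocabulary (stand-in for `out.logits`).\n",
        "logits = jnp.array([2.0, 0.5, 1.0, -1.0, 0.0])\n",
        "\n",
        "# Normalize the logits into a probability distribution.\n",
        "probs = jax.nn.softmax(logits)\n",
        "\n",
        "# Keep the 3 most likely token ids, sorted by decreasing probability.\n",
        "top_probs, top_ids = jax.lax.top_k(probs, k=3)\n",
        "print(top_ids, top_probs)\n",
        "```\n",
        "\n",
        "With real logits, each id in `top_ids` can be passed to `tokenizer.decode` to recover the token text."
      ]
    },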
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zDtnoRhCuk9d"
      },
      "source": [
        "## Training\n",
        "\n",
        "To use sharding during training, simply set the `sharding=` attribute of the trainer, like:\n",
        "\n",
        "```python\n",
        "trainer = kd.train.Trainer(\n",
        "    ...,\n",
        "    sharding=kd.sharding.ShardingStrategy(\n",
        "        params=kd.sharding.FSDPSharding(),\n",
        "    ),\n",
        "    ...,\n",
        ")\n",
        "```\n",
        "\n",
        "See a full example at: https://github.com/google-deepmind/gemma/tree/main/examples/sharding.py"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "last_runtime": {},
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
